{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vkdnLiKk71g-"
      },
      "source": [
        "##### Copyright 2021 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "0asMuNro71hA"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iPFgLeZIsZ3Q"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/federated/tutorials/custom_federated_algorithm_with_tff_optimizers\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/federated/blob/v0.84.0/docs/tutorials/custom_federated_algorithm_with_tff_optimizers.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/federated/blob/v0.84.0/docs/tutorials/custom_federated_algorithm_with_tff_optimizers.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/federated/docs/tutorials/custom_federated_algorithm_with_tff_optimizers.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4Zv28F7QLo8O"
      },
      "source": [
        "# Use TFF optimizers in a custom iterative process\n",
        "\n",
        "This is an alternative to the [Build Your Own Federated Learning Algorithm](building_your_own_federated_learning_algorithm.ipynb) tutorial and the [simple_fedavg](https://github.com/tensorflow/federated/tree/main/examples/simple_fedavg) example for building a custom iterative process for the [federated averaging](https://arxiv.org/abs/1602.05629) algorithm. This tutorial uses [TFF optimizers](https://github.com/tensorflow/federated/tree/main/tensorflow_federated/python/learning/optimizers) instead of Keras optimizers.\n",
        "The TFF optimizer abstraction is designed to be state-in-state-out, which makes it easier to incorporate into a TFF iterative process. The `tff.learning` APIs also accept TFF optimizers as input arguments."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MnUwFbCAKB2r"
      },
      "source": [
        "## Before we start\n",
        "\n",
        "Before we start, please run the following to make sure that your environment is\n",
        "set up correctly. If the installation or the imports fail, please refer to the\n",
        "[Installation](../install.md) guide for instructions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZrGitA_KnRO0"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "!pip install --quiet --upgrade tensorflow-federated"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "HGTM6tWOLo8M"
      },
      "outputs": [],
      "source": [
        "from typing import Any\n",
        "import functools\n",
        "\n",
        "import attrs\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow_federated as tff"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hQ_N9XbULo8P"
      },
      "source": [
        "## Preparing data and model\n",
        "The EMNIST data processing and model are very similar to the [simple_fedavg](https://github.com/tensorflow/federated/tree/main/examples/simple_fedavg) example."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "Blrh8zJgLo8R"
      },
      "outputs": [],
      "source": [
        "only_digits = True\n",
        "\n",
        "# Load dataset.\n",
        "emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(only_digits)\n",
        "\n",
        "# Define preprocessing functions.\n",
        "def preprocess_fn(dataset, batch_size=16):\n",
        "\n",
        "  def batch_format_fn(element):\n",
        "    return (tf.expand_dims(element['pixels'], -1), element['label'])\n",
        "\n",
        "  return dataset.batch(batch_size).map(batch_format_fn)\n",
        "\n",
        "# Preprocess the training clients. For quick prototyping, the first training\n",
        "# client's data doubles as the centralized evaluation set.\n",
        "train_client_ids = sorted(emnist_train.client_ids)\n",
        "train_data = emnist_train.preprocess(preprocess_fn)\n",
        "central_test_data = preprocess_fn(\n",
        "    emnist_train.create_tf_dataset_for_client(train_client_ids[0]))\n",
        "\n",
        "# Define model.\n",
        "def create_keras_model():\n",
        "  \"\"\"The CNN model used in https://arxiv.org/abs/1602.05629.\"\"\"\n",
        "  data_format = 'channels_last'\n",
        "  input_shape = [28, 28, 1]\n",
        "\n",
        "  max_pool = functools.partial(\n",
        "      tf.keras.layers.MaxPooling2D,\n",
        "      pool_size=(2, 2),\n",
        "      padding='same',\n",
        "      data_format=data_format)\n",
        "  conv2d = functools.partial(\n",
        "      tf.keras.layers.Conv2D,\n",
        "      kernel_size=5,\n",
        "      padding='same',\n",
        "      data_format=data_format,\n",
        "      activation=tf.nn.relu)\n",
        "\n",
        "  model = tf.keras.models.Sequential([\n",
        "      conv2d(filters=32, input_shape=input_shape),\n",
        "      max_pool(),\n",
        "      conv2d(filters=64),\n",
        "      max_pool(),\n",
        "      tf.keras.layers.Flatten(),\n",
        "      tf.keras.layers.Dense(512, activation=tf.nn.relu),\n",
        "      tf.keras.layers.Dense(10 if only_digits else 62),\n",
        "  ])\n",
        "\n",
        "  return model\n",
        "\n",
        "# Wrap as `tff.learning.models.VariableModel`.\n",
        "def model_fn():\n",
        "  keras_model = create_keras_model()\n",
        "  return tff.learning.models.from_keras_model(\n",
        "      keras_model,\n",
        "      input_spec=central_test_data.element_spec,\n",
        "      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fPOWP2JjsfTk"
      },
      "source": [
        "## Custom iterative process\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "50N36Zz8qyY-"
      },
      "source": [
        "In many cases, federated algorithms have 4 main components:\n",
        "\n",
        "1. A server-to-client broadcast step.\n",
        "2. A local client update step.\n",
        "3. A client-to-server upload step.\n",
        "4. A server update step.\n",
        "\n",
        "In TFF, we generally represent federated algorithms as a [`tff.templates.IterativeProcess`](https://www.tensorflow.org/federated/api_docs/python/tff/templates/IterativeProcess) (which we refer to as just an `IterativeProcess` throughout). This is a class that contains `initialize` and `next` functions. Here, `initialize` is used to initialize the server, and `next` will perform one communication round of the federated algorithm.\n",
        "\n",
        "We will introduce the components needed to build the federated averaging (FedAvg) algorithm, which uses one optimizer in the client update step and another optimizer in the server update step. The core logic of the client and server updates can be expressed as pure TF blocks."
      ]
    },
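    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The four steps above can be sketched without TFF. The following is a minimal NumPy illustration, assuming a toy linear least-squares model and plain gradient descent on every client. The names `toy_client_update` and `toy_fedavg_round` are hypothetical and are not part of TFF or of the implementation built later in this tutorial.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def toy_client_update(server_weights, features, labels,\n",
        "                      learning_rate=0.1, steps=5):\n",
        "  # Step 2: local training, starting from the broadcast server weights.\n",
        "  weights = server_weights.copy()\n",
        "  for _ in range(steps):\n",
        "    residual = features @ weights - labels\n",
        "    grad = features.T @ residual / len(labels)\n",
        "    weights = weights - learning_rate * grad\n",
        "  # Only the model delta is uploaded in step 3.\n",
        "  return weights - server_weights\n",
        "\n",
        "def toy_fedavg_round(server_weights, client_datasets,\n",
        "                     server_learning_rate=1.0):\n",
        "  # Step 1: broadcast, every client receives the same server weights.\n",
        "  deltas = [toy_client_update(server_weights, x, y)\n",
        "            for x, y in client_datasets]\n",
        "  # Step 3: aggregate the uploaded deltas with an unweighted mean.\n",
        "  mean_delta = np.mean(deltas, axis=0)\n",
        "  # Step 4: server update.\n",
        "  return server_weights + server_learning_rate * mean_delta\n",
        "\n",
        "# Three toy clients whose noiseless data share the same true weights.\n",
        "rng = np.random.default_rng(0)\n",
        "true_weights = np.array([2.0, -1.0])\n",
        "client_datasets = []\n",
        "for _ in range(3):\n",
        "  x = rng.normal(size=(20, 2))\n",
        "  client_datasets.append((x, x @ true_weights))\n",
        "\n",
        "weights = np.zeros(2)\n",
        "for _ in range(100):\n",
        "  weights = toy_fedavg_round(weights, client_datasets)\n",
        "```\n",
        "\n",
        "In this toy setting all clients agree on the optimum, so repeated rounds move `weights` toward `true_weights`. The rest of the tutorial replaces the hand-written gradient steps with TFF optimizers and the Python list comprehension with TFF federated operators."
      ]
    },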
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bxpNYucgLo8g"
      },
      "source": [
        "### TF blocks: client and server update\n",
        "\n",
        "On each client, a local `client_optimizer` is initialized and used to update the client model weights. On the server, `server_optimizer` will use the state from the *previous* round, and update the state for the next round."
      ]
    },
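    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the state-in-state-out pattern concrete, here is a minimal NumPy sketch. The class `ToySgdm` is hypothetical: it mirrors only the `initialize` and `next` calling convention used in this tutorial, it is not the `tff.learning.optimizers` implementation, and the momentum update rule shown is an assumption made for illustration.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "class ToySgdm:\n",
        "  # Hypothetical momentum SGD that stores hyperparameters but no state.\n",
        "\n",
        "  def __init__(self, learning_rate, momentum=0.0):\n",
        "    self.learning_rate = learning_rate\n",
        "    self.momentum = momentum\n",
        "\n",
        "  def initialize(self, shapes):\n",
        "    # The state is returned to the caller, not stored on the object.\n",
        "    return {'accumulator': [np.zeros(shape) for shape in shapes]}\n",
        "\n",
        "  def next(self, state, weights, gradients):\n",
        "    accumulator = [self.momentum * a + g\n",
        "                   for a, g in zip(state['accumulator'], gradients)]\n",
        "    new_weights = [w - self.learning_rate * a\n",
        "                   for w, a in zip(weights, accumulator)]\n",
        "    # New state and new weights are both outputs; inputs are not mutated.\n",
        "    return {'accumulator': accumulator}, new_weights\n",
        "\n",
        "optimizer = ToySgdm(learning_rate=0.1, momentum=0.9)\n",
        "weights = [np.array([1.0, 2.0])]\n",
        "state = optimizer.initialize([w.shape for w in weights])\n",
        "state, weights = optimizer.next(state, weights, [np.array([1.0, 1.0])])\n",
        "state, weights = optimizer.next(state, weights, [np.array([1.0, 1.0])])\n",
        "```\n",
        "\n",
        "Because the state lives outside the optimizer object, the caller decides where it is kept: a client can create a fresh state every round, while the server can carry its state from one round to the next."
      ]
    },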
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "c5rHPKreLo8g"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def client_update(model, dataset, server_weights, client_optimizer):\n",
        "  \"\"\"Performs local training on the client's dataset.\"\"\"\n",
        "  # Get the trainable variables of the client model.\n",
        "  client_weights = model.trainable_variables\n",
        "  # Assign the server weights to the client model.\n",
        "  tf.nest.map_structure(lambda x, y: x.assign(y),\n",
        "                        client_weights, server_weights)\n",
        "  # Initialize the client optimizer state from the weight shapes and dtypes.\n",
        "  trainable_tensor_specs = tf.nest.map_structure(\n",
        "      lambda v: tf.TensorSpec(v.shape, v.dtype), client_weights)\n",
        "  optimizer_state = client_optimizer.initialize(trainable_tensor_specs)\n",
        "  # Use the client_optimizer to update the local model.\n",
        "  for batch in iter(dataset):\n",
        "    with tf.GradientTape() as tape:\n",
        "      # Compute a forward pass on the batch of data.\n",
        "      outputs = model.forward_pass(batch)\n",
        "    # Compute the corresponding gradient.\n",
        "    grads = tape.gradient(outputs.loss, client_weights)\n",
        "    # Apply the gradient using a client optimizer.\n",
        "    optimizer_state, updated_weights = client_optimizer.next(\n",
        "        optimizer_state, client_weights, grads)\n",
        "    tf.nest.map_structure(lambda a, b: a.assign(b),\n",
        "                          client_weights, updated_weights)\n",
        "  # Return model deltas.\n",
        "  return tf.nest.map_structure(tf.subtract, client_weights, server_weights)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "rYxErLvHLo8i"
      },
      "outputs": [],
      "source": [
        "@attrs.define(eq=False, frozen=True)\n",
        "class ServerState(object):\n",
        "  trainable_weights: Any\n",
        "  optimizer_state: Any\n",
        "\n",
        "@tf.function\n",
        "def server_update(server_state, mean_model_delta, server_optimizer):\n",
        "  \"\"\"Updates the server model weights.\"\"\"\n",
        "  # Use aggregated negative model delta as pseudo gradient.\n",
        "  negative_weights_delta = tf.nest.map_structure(\n",
        "      lambda w: -1.0 * w, mean_model_delta)\n",
        "  new_optimizer_state, updated_weights = server_optimizer.next(\n",
        "      server_state.optimizer_state, server_state.trainable_weights,\n",
        "      negative_weights_delta)\n",
        "  return tff.structure.update_struct(\n",
        "      server_state,\n",
        "      trainable_weights=updated_weights,\n",
        "      optimizer_state=new_optimizer_state)"
      ]
    },
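    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The sign flip in `server_update` is easiest to see in the special case of plain SGD with a learning rate of 1.0, where the server step reduces to adding the mean client delta to the server weights. The following NumPy check is a sketch under that assumption; `sgd_step` is a hypothetical helper, not a TFF function.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def sgd_step(weights, gradient, learning_rate):\n",
        "  # Plain gradient descent: move against the gradient.\n",
        "  return weights - learning_rate * gradient\n",
        "\n",
        "server_weights = np.array([0.5, -0.5])\n",
        "mean_model_delta = np.array([0.1, 0.2])\n",
        "\n",
        "# Treat the negative delta as a pseudo-gradient, as server_update does.\n",
        "pseudo_gradient = -1.0 * mean_model_delta\n",
        "updated = sgd_step(server_weights, pseudo_gradient, learning_rate=1.0)\n",
        "\n",
        "# The two minus signs cancel, so the server moves along the mean delta.\n",
        "assert np.allclose(updated, server_weights + mean_model_delta)\n",
        "```\n",
        "\n",
        "With a different server optimizer, such as the SGD with momentum configured later in this tutorial, the same pseudo-gradient feeds a momentum update instead of a plain averaging step."
      ]
    },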
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g0zNTO7LLo84"
      },
      "source": [
        "### TFF blocks: `tff.tensorflow.computation` and `tff.federated_computation`\n",
        "\n",
        "We now use TFF for orchestration and build the iterative process for FedAvg. We wrap the TF blocks defined above with `tff.tensorflow.computation`, and use the TFF operators `tff.federated_broadcast`, `tff.federated_map`, and `tff.federated_mean` in a `tff.federated_computation` function. Because a `tff.learning.optimizers.Optimizer` exposes `initialize` and `next` functions that pass the optimizer state explicitly, it slots directly into a custom iterative process."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "jJY9xUBZLo84"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "server_state_type:\n",
            " \u003c\n",
            "  trainable_weights=\u003c\n",
            "    float32[5,5,1,32],\n",
            "    float32[32],\n",
            "    float32[5,5,32,64],\n",
            "    float32[64],\n",
            "    float32[3136,512],\n",
            "    float32[512],\n",
            "    float32[512,10],\n",
            "    float32[10]\n",
            "  \u003e,\n",
            "  optimizer_state=\u003c\n",
            "    learning_rate=float32,\n",
            "    momentum=float32,\n",
            "    accumulator=\u003c\n",
            "      float32[5,5,1,32],\n",
            "      float32[32],\n",
            "      float32[5,5,32,64],\n",
            "      float32[64],\n",
            "      float32[3136,512],\n",
            "      float32[512],\n",
            "      float32[512,10],\n",
            "      float32[10]\n",
            "    \u003e\n",
            "  \u003e\n",
            "\u003e\n",
            "trainable_weights_type:\n",
            " \u003c\n",
            "  float32[5,5,1,32],\n",
            "  float32[32],\n",
            "  float32[5,5,32,64],\n",
            "  float32[64],\n",
            "  float32[3136,512],\n",
            "  float32[512],\n",
            "  float32[512,10],\n",
            "  float32[10]\n",
            "\u003e\n",
            "tf_dataset_type:\n",
            " \u003c\n",
            "  float32[?,28,28,1],\n",
            "  int32[?]\n",
            "\u003e*\n",
            "type signature of `initialize`:\n",
            " ( -\u003e \u003c\n",
            "  trainable_weights=\u003c\n",
            "    float32[5,5,1,32],\n",
            "    float32[32],\n",
            "    float32[5,5,32,64],\n",
            "    float32[64],\n",
            "    float32[3136,512],\n",
            "    float32[512],\n",
            "    float32[512,10],\n",
            "    float32[10]\n",
            "  \u003e,\n",
            "  optimizer_state=\u003c\n",
            "    learning_rate=float32,\n",
            "    momentum=float32,\n",
            "    accumulator=\u003c\n",
            "      float32[5,5,1,32],\n",
            "      float32[32],\n",
            "      float32[5,5,32,64],\n",
            "      float32[64],\n",
            "      float32[3136,512],\n",
            "      float32[512],\n",
            "      float32[512,10],\n",
            "      float32[10]\n",
            "    \u003e\n",
            "  \u003e\n",
            "\u003e@SERVER)\n",
            "type signature of `next`:\n",
            " (\u003c\n",
            "  server_state=\u003c\n",
            "    trainable_weights=\u003c\n",
            "      float32[5,5,1,32],\n",
            "      float32[32],\n",
            "      float32[5,5,32,64],\n",
            "      float32[64],\n",
            "      float32[3136,512],\n",
            "      float32[512],\n",
            "      float32[512,10],\n",
            "      float32[10]\n",
            "    \u003e,\n",
            "    optimizer_state=\u003c\n",
            "      learning_rate=float32,\n",
            "      momentum=float32,\n",
            "      accumulator=\u003c\n",
            "        float32[5,5,1,32],\n",
            "        float32[32],\n",
            "        float32[5,5,32,64],\n",
            "        float32[64],\n",
            "        float32[3136,512],\n",
            "        float32[512],\n",
            "        float32[512,10],\n",
            "        float32[10]\n",
            "      \u003e\n",
            "    \u003e\n",
            "  \u003e@SERVER,\n",
            "  federated_dataset={\u003c\n",
            "    float32[?,28,28,1],\n",
            "    int32[?]\n",
            "  \u003e*}@CLIENTS\n",
            "\u003e -\u003e \u003c\n",
            "  trainable_weights=\u003c\n",
            "    float32[5,5,1,32],\n",
            "    float32[32],\n",
            "    float32[5,5,32,64],\n",
            "    float32[64],\n",
            "    float32[3136,512],\n",
            "    float32[512],\n",
            "    float32[512,10],\n",
            "    float32[10]\n",
            "  \u003e,\n",
            "  optimizer_state=\u003c\n",
            "    learning_rate=float32,\n",
            "    momentum=float32,\n",
            "    accumulator=\u003c\n",
            "      float32[5,5,1,32],\n",
            "      float32[32],\n",
            "      float32[5,5,32,64],\n",
            "      float32[64],\n",
            "      float32[3136,512],\n",
            "      float32[512],\n",
            "      float32[512,10],\n",
            "      float32[10]\n",
            "    \u003e\n",
            "  \u003e\n",
            "\u003e@SERVER)\n"
          ]
        }
      ],
      "source": [
        "# 1. Server and client optimizer to be used.\n",
        "server_optimizer = tff.learning.optimizers.build_sgdm(\n",
        "    learning_rate=0.05, momentum=0.9)\n",
        "client_optimizer = tff.learning.optimizers.build_sgdm(\n",
        "    learning_rate=0.01)\n",
        "\n",
        "# 2. Functions that return the initial state on the server.\n",
        "@tff.tensorflow.computation\n",
        "def server_init():\n",
        "  model = model_fn()\n",
        "  trainable_tensor_specs = tf.nest.map_structure(\n",
        "      lambda v: tf.TensorSpec(v.shape, v.dtype), model.trainable_variables)\n",
        "  optimizer_state = server_optimizer.initialize(trainable_tensor_specs)\n",
        "  return ServerState(\n",
        "      trainable_weights=model.trainable_variables,\n",
        "      optimizer_state=optimizer_state)\n",
        "\n",
        "@tff.federated_computation\n",
        "def server_init_tff():\n",
        "  return tff.federated_value(server_init(), tff.SERVER)\n",
        "\n",
        "# 3. One round of computation and communication.\n",
        "server_state_type = server_init.type_signature.result\n",
        "print('server_state_type:\\n',\n",
        "      server_state_type.formatted_representation())\n",
        "trainable_weights_type = server_state_type.trainable_weights\n",
        "print('trainable_weights_type:\\n',\n",
        "      trainable_weights_type.formatted_representation())\n",
        "\n",
        "# 3-1. Wrap server and client TF blocks with `tff.tensorflow.computation`.\n",
        "@tff.tensorflow.computation(server_state_type, trainable_weights_type)\n",
        "def server_update_fn(server_state, model_delta):\n",
        "  return server_update(server_state, model_delta, server_optimizer)\n",
        "\n",
        "whimsy_model = model_fn()\n",
        "tf_dataset_type = tff.SequenceType(tff.types.tensorflow_to_type(whimsy_model.input_spec))\n",
        "print('tf_dataset_type:\\n',\n",
        "      tf_dataset_type.formatted_representation())\n",
        "@tff.tensorflow.computation(tf_dataset_type, trainable_weights_type)\n",
        "def client_update_fn(dataset, server_weights):\n",
        "  model = model_fn()\n",
        "  return client_update(model, dataset, server_weights, client_optimizer)\n",
        "\n",
        "# 3-2. Orchestration with `tff.federated_computation`.\n",
        "federated_server_type = tff.FederatedType(server_state_type, tff.SERVER)\n",
        "federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)\n",
        "@tff.federated_computation(federated_server_type, federated_dataset_type)\n",
        "def run_one_round(server_state, federated_dataset):\n",
        "  # Server-to-client broadcast.\n",
        "  server_weights_at_client = tff.federated_broadcast(\n",
        "      server_state.trainable_weights)\n",
        "  # Local client update.\n",
        "  model_deltas = tff.federated_map(\n",
        "      client_update_fn, (federated_dataset, server_weights_at_client))\n",
        "  # Client-to-server upload and aggregation.\n",
        "  mean_model_delta = tff.federated_mean(model_deltas)\n",
        "  # Server update.\n",
        "  server_state = tff.federated_map(\n",
        "      server_update_fn, (server_state, mean_model_delta))\n",
        "  return server_state\n",
        "\n",
        "# 4. Build the iterative process for FedAvg.\n",
        "fedavg_process = tff.templates.IterativeProcess(\n",
        "    initialize_fn=server_init_tff, next_fn=run_one_round)\n",
        "print('type signature of `initialize`:\\n',\n",
        "      fedavg_process.initialize.type_signature.formatted_representation())\n",
        "print('type signature of `next`:\\n',\n",
        "      fedavg_process.next.type_signature.formatted_representation())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4UYZ3qeMLo9N"
      },
      "source": [
        "## Evaluating the algorithm"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jwd9Gs0ULo9O"
      },
      "source": [
        "We evaluate the model on a centralized evaluation dataset, `central_test_data`, which was built above from the first training client's examples."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "EdNgYoIwLo9P"
      },
      "outputs": [],
      "source": [
        "def evaluate(server_state):\n",
        "  keras_model = create_keras_model()\n",
        "  tf.nest.map_structure(\n",
        "      lambda var, t: var.assign(t),\n",
        "      keras_model.trainable_weights, server_state.trainable_weights)\n",
        "  metric = tf.keras.metrics.SparseCategoricalAccuracy()\n",
        "  for batch in iter(central_test_data):\n",
        "    preds = keras_model(batch[0], training=False)\n",
        "    metric.update_state(y_true=batch[1], y_pred=preds)\n",
        "  return metric.result().numpy()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "CDarZn71G2mH"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Initial test accuracy 0.06451613\n",
            "Test accuracy 0.086021505\n"
          ]
        }
      ],
      "source": [
        "server_state = fedavg_process.initialize()\n",
        "acc = evaluate(server_state)\n",
        "print('Initial test accuracy', acc)\n",
        "\n",
        "# Train for a few rounds on a fixed sample of clients, then evaluate.\n",
        "CLIENTS_PER_ROUND = 2\n",
        "sampled_clients = train_client_ids[:CLIENTS_PER_ROUND]\n",
        "sampled_train_data = [\n",
        "    train_data.create_tf_dataset_for_client(client)\n",
        "    for client in sampled_clients]\n",
        "for round_num in range(20):\n",
        "  server_state = fedavg_process.next(server_state, sampled_train_data)\n",
        "acc = evaluate(server_state)\n",
        "print('Test accuracy', acc)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "custom_federated_algorithm_with_tff_optimizers.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
